#Graph with numerics
import pickle
import matplotlib.pyplot as plt
import numpy as np
# Load the saved metrics data from disk.
# NOTE(review): pickle can execute arbitrary code while loading; this assumes
# the .pkl file was produced locally by a trusted run — confirm before
# pointing this at files from elsewhere.
with open("mlp_cifar10_2_2_10epo_cuda.pkl", "rb") as metrics_file:
    metrics = pickle.load(metrics_file)
# Plot one per-task metric series as curves on a shared epoch axis.
def plot_ntk_metrics(ntk_metrics, title, ylabel, epochs_per_task):
    """Draw one curve per task on a shared axis and show the figure.

    Task ``t``'s curve starts at x = ``t * epochs_per_task``, so each curve
    is shifted to the position on the axis where its task is introduced.

    Args:
        ntk_metrics: Container with one entry per task, read as
            ``ntk_metrics[0]`` .. ``ntk_metrics[len(ntk_metrics) - 1]``.
            Each entry is a sequence of y-values for that task.
        title: Figure title.
        ylabel: Label for the y-axis.
        epochs_per_task: Horizontal offset between consecutive tasks; also
            the spacing between the "Task i" tick marks. Must be non-zero
            because it is used as the step of ``np.arange``.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    An empty ``ntk_metrics`` yields an empty, titled figure rather than an
    error.
    """
    num_tasks = len(ntk_metrics)

    plt.figure(figsize=(10, 6))
    # One colour per task, sampled from the tab10 colormap.
    colors = plt.cm.tab10(np.linspace(0, 1, num_tasks))

    # Furthest x position (exclusive) reached by any curve; drives the ticks.
    x_end = 0
    # Index by position instead of iterating the container directly so that
    # inputs keyed 0..n-1 (not only lists) keep working.
    for t in range(num_tasks):
        task_values = ntk_metrics[t]
        # The x-axis for task t begins at the epoch where the task is added.
        x_start = t * epochs_per_task
        x_stop = x_start + len(task_values)
        x_end = max(x_end, x_stop)
        plt.plot(range(x_start, x_stop), task_values, marker='o',
                 label=f'Task {t} Dataset', color=colors[t])

    # Custom x ticks: one tick at the start of every task-sized interval.
    xticks = np.arange(0, x_end, epochs_per_task)
    xticklabels = [f"Task {i}" for i in range(len(xticks))]
    plt.xticks(xticks, xticklabels)

    plt.xlabel('Task')
    plt.ylabel(ylabel)
    plt.title(title)
    if num_tasks:
        # Skip the legend when nothing was plotted; there are no labelled
        # artists to list in that case.
        plt.legend()
    plt.grid(True)
    plt.show()
# Compute NTK summary statistics per task and plot them for every width.
def calculate_ntk_metrics(ntk_matrices):
    """Summarise nested NTK matrices by norm and extreme eigenvalues.

    Args:
        ntk_matrices: Iterable with one entry per task; each entry is an
            iterable of square matrices accepted by ``np.linalg.eigvalsh``
            (presumably one matrix per recorded epoch — confirm against the
            code that wrote the pickle).

    Returns:
        A tuple ``(norms, max_eigenvalues, min_eigenvalues)``. Each element
        is a list with one inner list per task, holding respectively the
        Frobenius norm, the largest eigenvalue and the smallest eigenvalue
        of every matrix of that task, in input order.
    """
    norms = []
    max_eigenvalues = []
    min_eigenvalues = []
    for task_ntk_matrices in ntk_matrices:
        task_norms = []
        task_max_eigenvalues = []
        task_min_eigenvalues = []
        for ntk_matrix in task_ntk_matrices:
            task_norms.append(np.linalg.norm(ntk_matrix, ord='fro'))
            # eigvalsh returns eigenvalues in ascending order, so the first
            # and last entries are the minimum and maximum.
            eigenvalues = np.linalg.eigvalsh(ntk_matrix)
            task_max_eigenvalues.append(eigenvalues[-1])
            task_min_eigenvalues.append(eigenvalues[0])
        norms.append(task_norms)
        max_eigenvalues.append(task_max_eigenvalues)
        min_eigenvalues.append(task_min_eigenvalues)
    return norms, max_eigenvalues, min_eigenvalues


# Keys selecting which experiment inside ``metrics`` is plotted.
network = "MLP"
dataset = "CIFAR10"
increment = "2-2-10"
# Offset between consecutive tasks on the x-axis, passed to every plot call.
# NOTE(review): presumably matches the "10" in the increment key / "10epo" in
# the pickle file name — confirm if either changes.
EPOCHS_PER_TASK = 10

# Iterate over the different widths and draw the figures for each one.
for width, data in metrics[network][dataset][increment].items():
    # Skip widths for which no train NTK matrices were recorded.
    if not data["train_ntk_matrices"]:
        continue

    train_ntk_matrices = data["train_ntk_matrices"]
    test_ntk_matrices = data["test_ntk_matrices"]
    task_accuracies = data["task_accuracies"]

    # Compute the NTK statistics for the train and test matrices.
    (train_ntk_norms,
     train_ntk_max_eigenvalues,
     train_ntk_min_eigenvalues) = calculate_ntk_metrics(train_ntk_matrices)
    (test_ntk_norms,
     test_ntk_max_eigenvalues,
     test_ntk_min_eigenvalues) = calculate_ntk_metrics(test_ntk_matrices)

    # Title prefix shared by every figure of this width.
    base_title = f"Network: {network}, Dataset: {dataset}, 2init-2increment, Width: {width}"

    # (series, title heading, y-axis label) for each figure, in display order:
    # norms, max eigenvalues, min eigenvalues, then task accuracy.
    plot_specs = [
        (train_ntk_norms, "Train NTK Norms", 'Train NTK Norm'),
        (test_ntk_norms, "Test NTK Norms", 'Test NTK Norm'),
        (train_ntk_max_eigenvalues, "Train NTK Max Eigenvalues", 'Train NTK Max Eigenvalue'),
        (test_ntk_max_eigenvalues, "Test NTK Max Eigenvalues", 'Test NTK Max Eigenvalue'),
        (train_ntk_min_eigenvalues, "Train NTK Min Eigenvalues", 'Train NTK Min Eigenvalue'),
        (test_ntk_min_eigenvalues, "Test NTK Min Eigenvalues", 'Test NTK Min Eigenvalue'),
        (task_accuracies, "Test Accuracy", 'Test Accuracy'),
    ]
    for series, heading, ylabel in plot_specs:
        plot_ntk_metrics(series,
                         title=f"{base_title}\n{heading} for Each Task's Dataset Across Epochs",
                         ylabel=ylabel,
                         epochs_per_task=EPOCHS_PER_TASK)